{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6011adf8",
   "metadata": {},
   "source": [
    "# 10 分钟快速上手 fastNLP torch\n",
    "\n",
    "在这个例子中，我们将使用BERT来解决conll2003数据集中的命名实体识别任务。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e166c051",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2022-07-07 10:12:29--  https://data.deepai.org/conll2003.zip\n",
      "Resolving data.deepai.org (data.deepai.org)... 138.201.36.183\n",
      "Connecting to data.deepai.org (data.deepai.org)|138.201.36.183|:443... connected.\n",
      "WARNING: cannot verify data.deepai.org's certificate, issued by ‘CN=R3,O=Let's Encrypt,C=US’:\n",
      "  Issued certificate has expired.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 982975 (960K) [application/x-zip-compressed]\n",
      "Saving to: ‘conll2003.zip’\n",
      "\n",
      "conll2003.zip       100%[===================>] 959.94K   653KB/s    in 1.5s    \n",
      "\n",
      "2022-07-07 10:12:32 (653 KB/s) - ‘conll2003.zip’ saved [982975/982975]\n",
      "\n",
      "Archive:  conll2003.zip\n",
      "  inflating: conll2003/metadata      \n",
      "  inflating: conll2003/test.txt      \n",
      "  inflating: conll2003/train.txt     \n",
      "  inflating: conll2003/valid.txt     \n"
     ]
    }
   ],
   "source": [
    "# Linux/Mac 下载数据，并解压\n",
    "import platform\n",
    "if platform.system() != \"Windows\":\n",
    "    !wget https://data.deepai.org/conll2003.zip --no-check-certificate -O conll2003.zip\n",
    "    !unzip conll2003.zip -d conll2003\n",
    "# Windows用户请通过复制该url到浏览器下载该数据并解压"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7acbf1f",
   "metadata": {},
   "source": [
    "## 目录\n",
    "接下来我们将按照以下的内容介绍在如何通过fastNLP减少工程性代码的撰写 \n",
    "- 1. 数据加载\n",
    "- 2. 数据预处理、数据缓存\n",
    "- 3. DataLoader\n",
    "- 4. 模型准备\n",
    "- 5. Trainer的使用\n",
    "- 6. Evaluator的使用\n",
    "- 7. 其它【待补充】\n",
    "    - 7.1 使用多卡进行训练、评测\n",
    "    - 7.2 使用ZeRO优化\n",
    "    - 7.3 通过overfit测试快速验证模型\n",
    "    - 7.4 复杂Monitor的使用\n",
    "    - 7.5 训练过程中，使用不同的测试函数\n",
    "    - 7.6 更有效率的Sampler\n",
    "    - 7.7 保存模型\n",
    "    - 7.8 断点重训\n",
    "    - 7.9 使用huggingface datasets\n",
    "    - 7.10 使用torchmetrics来作为metric\n",
    "    - 7.11 将预测结果写出到文件\n",
    "    - 7.12 混合 dataset 训练\n",
    "    - 7.13 logger的使用\n",
    "    - 7.14 自定义分布式 Metric 。\n",
    "    - 7.15 通过batch_step_fn实现R-Drop"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0657dfba",
   "metadata": {},
   "source": [
    "#### 1.  数据加载\n",
    "目前在``conll2003``目录下有``train.txt``, ``test.txt``与``valid.txt``三个文件，文件的格式为[conll格式](https://universaldependencies.org/format.html)，其编码格式为 [BIO](https://blog.csdn.net/HappyRocking/article/details/79716212) 类型。可以通过继承 fastNLP.io.Loader 来简化加载过程，继承了 Loader 函数后，只需要在实现读取单个文件 _load() 函数即可。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c557f0ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6f59e438",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In total 3 datasets:\n",
      "\ttrain has 14987 instances.\n",
      "\ttest has 3684 instances.\n",
      "\tdev has 3466 instances.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from fastNLP import DataSet, Instance\n",
    "from fastNLP.io import Loader\n",
    "\n",
    "\n",
    "# 继承Loader之后，我们只需要实现其中_load()方法，_load()方法传入一个文件路径，返回一个fastNLP DataSet对象，其目的是读取一个文件。\n",
    "class ConllLoader(Loader):\n",
    "    def _load(self, path):\n",
    "        ds = DataSet()\n",
    "        with open(path, 'r') as f:\n",
    "            segments = []\n",
    "            for line in f:\n",
    "                line = line.strip()\n",
    "                if line == '':  # 如果为空行，说明需要切换到下一句了。\n",
    "                    if segments:\n",
    "                        raw_words = [s[0] for s in segments]\n",
    "                        raw_target = [s[1] for s in segments]\n",
    "                        # 将一个 sample 插入到 DataSet中\n",
    "                        ds.append(Instance(raw_words=raw_words, raw_target=raw_target))  \n",
    "                    segments = []\n",
    "                else:\n",
    "                    parts = line.split()\n",
    "                    assert len(parts)==4\n",
    "                    segments.append([parts[0], parts[-1]])\n",
    "        return ds\n",
    "    \n",
    "\n",
    "# 直接使用 load() 方法加载数据集, 返回的 data_bundle 是一个 fastNLP.io.DataBundle 对象，该对象相当于将多个 dataset 放置在一起，\n",
    "#  可以方便之后的预处理，DataBundle 支持的接口可以在 ！！！ 查看。\n",
    "data_bundle = ConllLoader().load({\n",
    "    'train': 'conll2003/train.txt',\n",
    "    'test': 'conll2003/test.txt',\n",
    "    'dev': 'conll2003/valid.txt'\n",
    "})\n",
    "\"\"\"\n",
    "也可以通过 ConllLoader().load('conll2003/') 来读取，其原理是load()函数将尝试从'conll2003/'文件夹下寻找文件名称中包含了\n",
    "'train'、'test'和'dev'的文件，并分别读取将其命名为'train'、'test'和'dev'（如文件夹中同一个关键字出现在了多个文件名中将导致报错，\n",
    "此时请通过dict的方式传入路径信息）。但在我们这里的数据里，没有文件包含dev，所以无法直接使用文件夹读取，转而通过dict的方式传入读取的路径，\n",
    "该dict的key也将作为读取的数据集的名称，value即对应的文件路径。\n",
    "\"\"\"\n",
    "\n",
    "print(data_bundle)  # 打印 data_bundle 可以查看包含的 DataSet \n",
    "# data_bundle.get_dataset('train')  # 可以获取单个 dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57ae314d",
   "metadata": {},
   "source": [
    "#### 2.  数据预处理\n",
    "接下来，我们将演示如何通过fastNLP提供的apply函数方便快捷地进行预处理。我们需要进行的预处理操作有：  \n",
    "（1）使用BertTokenizer将文本转换为index；同时记录每个word被bpe之后第一个bpe的index，用于得到word的hidden state；  \n",
    "（2）使用[Vocabulary](../fastNLP)来将raw_target转换为序号。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "96389988",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c3bd41a323c94a41b409d29a5d4079b6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:48:13] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> Save cache to <span style=\"color: #800080; text-decoration-color: #800080\">/remote-home/hyan01/exps/fastNLP/fastN</span> <a href=\"file://../../fastNLP/core/utils/cache_results.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">cache_results.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/utils/cache_results.py#332\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">332</span></a>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         <span style=\"color: #800080; text-decoration-color: #800080\">LP/demo/torch_tutorial/caches/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">c7f74559_cache.pkl.</span>    <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                    </span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m[10:48:13]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Save cache to \u001b[35m/remote-home/hyan01/exps/fastNLP/fastN\u001b[0m \u001b]8;id=831330;file://../../fastNLP/core/utils/cache_results.py\u001b\\\u001b[2mcache_results.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=609545;file://../../fastNLP/core/utils/cache_results.py#332\u001b\\\u001b[2m332\u001b[0m\u001b]8;;\u001b\\\n",
       "\u001b[2;36m           \u001b[0m         \u001b[35mLP/demo/torch_tutorial/caches/\u001b[0m\u001b[95mc7f74559_cache.pkl.\u001b[0m    \u001b[2m                    \u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# fastNLP 中提供了BERT, RoBERTa, GPT, BART 模型，更多的预训练模型请直接使用transformers\n",
    "from fastNLP.transformers.torch import BertTokenizer\n",
    "from fastNLP import cache_results, Vocabulary\n",
    "\n",
    "# 使用cache_results来装饰函数，会将函数的返回结果缓存到'caches/{param_hash_id}_cache.pkl'路径中（其中{param_hash_id}是根据\n",
    "#   传递给 process_data 函数参数决定的，因此当函数的参数变化时，会再生成新的缓存文件。如果需要重新生成新的缓存，(a) 可以在调用process_data\n",
    "#   函数时，额外传入一个_refresh=True的参数; 或者（b）删除相应的缓存文件。此外，保存结果时，cache_results默认还会\n",
    "#   记录 process_data 函数源码的hash值，当其源码发生了变动，直接读取缓存会发出警告，以防止在修改预处理代码之后，忘记刷新缓存。）\n",
    "@cache_results('caches/cache.pkl')\n",
    "def process_data(data_bundle, model_name):\n",
    "    tokenizer = BertTokenizer.from_pretrained(model_name)\n",
    "    def bpe(raw_words):\n",
    "        bpes = [tokenizer.cls_token_id]\n",
    "        first = [0]\n",
    "        first_index = 1  # 记录第一个bpe的位置\n",
    "        for word in raw_words:\n",
    "            bpe = tokenizer.encode(word, add_special_tokens=False)\n",
    "            bpes.extend(bpe)\n",
    "            first.append(first_index)\n",
    "            first_index += len(bpe)\n",
    "        bpes.append(tokenizer.sep_token_id)\n",
    "        first.append(first_index)\n",
    "        return {'input_ids': bpes, 'input_len': len(bpes), 'first': first, 'first_len': len(raw_words)}\n",
    "    # 对data_bundle中每个dataset的每一条数据中的raw_words使用bpe函数，并且将返回的结果加入到每条数据中。\n",
    "    data_bundle.apply_field_more(bpe, field_name='raw_words', num_proc=4)\n",
    "    # 对应我们还有 apply_field() 函数，该函数和 apply_field_more() 的区别在于传入到 apply_field() 中的函数应该返回一个 field 的\n",
    "    #   内容（即不需要用dict包裹了）。此外，我们还提供了 data_bundle.apply() ，传入 apply() 的函数需要支持传入一个Instance对象，\n",
    "    #   更多信息可以参考对应的文档。\n",
    "    \n",
    "    # tag的词表，由于这是词表，所以不需要有padding和unk\n",
    "    tag_vocab = Vocabulary(padding=None, unknown=None)\n",
    "    # 从 train 数据的 raw_target 中获取建立词表\n",
    "    tag_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_target')\n",
    "    # 使用词表将每个 dataset 中的raw_target转为数字，并且将写入到target这个field中\n",
    "    tag_vocab.index_dataset(data_bundle.datasets.values(), field_name='raw_target', new_field_name='target')\n",
    "    \n",
    "    # 可以将 vocabulary 绑定到 data_bundle 上，方便之后使用。\n",
    "    data_bundle.set_vocab(tag_vocab, field_name='target')\n",
    "    \n",
    "    return data_bundle, tokenizer\n",
    "\n",
    "data_bundle, tokenizer = process_data(data_bundle, 'bert-base-cased', _refresh=True)  # 第一次调用耗时较长，第二次调用则会直接读取缓存的文件\n",
    "# data_bundle = process_data(data_bundle, 'bert-base-uncased')  # 由于参数变化，fastNLP 会再次生成新的缓存文件。    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80036fcd",
   "metadata": {},
   "source": [
    "### 3. DataLoader  \n",
    "由于现在的深度学习算法大都基于 mini-batch 进行优化，因此需要将多个 sample 组合成一个 batch 再输入到模型之中。在自然语言处理中，不同的 sample 往往长度不一致，需要进行 padding 操作。在fastNLP中，我们使用 fastNLP.TorchDataLoader 帮助用户快速进行 padding ，我们使用了 !!!fastNLP.Collator!!! 对象来进行 pad ，Collator 会在迭代过程中根据第一个 batch 的数据自动判定每个 field 是否可以进行 pad ，可以通过 Collator.set_pad() 函数修改某个 field 的 pad 行为。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "09494695",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import prepare_dataloader\n",
    "\n",
    "# 将 data_bundle 中每个 dataset 取出并构造出相应的 DataLoader 对象。返回的 dls 是一个 dict ，包含了 'train', 'test', 'dev' 三个\n",
    "#   fastNLP.TorchDataLoader 对象。\n",
    "dls = prepare_dataloader(data_bundle, batch_size=24) \n",
    "\n",
    "\n",
    "# fastNLP 将默认尝试对所有 field 都进行 pad ，如果当前 field 是不可 pad 的类型，则不进行pad；如果是可以 pad 的类型\n",
    "#   默认使用 0 进行 pad 。\n",
    "for dl in dls.values():\n",
    "    # 可以通过 set_pad 修改 padding 的行为。\n",
    "    dl.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
    "    # 如果希望忽略某个 field ，可以通过 set_ignore 方法。\n",
    "    dl.set_ignore('raw_target')\n",
    "    dl.set_pad('target', pad_val=-100)\n",
    "# 另一种设置的方法是，可以在 dls = prepare_dataloader(data_bundle, batch_size=32) 之前直接调用 \n",
    "#  data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id); data_bundle.set_ignore('raw_target')来进行设置。\n",
    "#  DataSet 也支持这两个方法。\n",
    "# 若此时调用 batch = next(dls['train'])，则 batch 是一个 dict ，其中包含了\n",
    "#  'input_ids': torch.LongTensor([batch_size, max_len])\n",
    "#  'input_len': torch.LongTensor([batch_size])\n",
    "#  'first': torch.LongTensor([batch_size, max_len'])\n",
    "#  'first_len': torch.LongTensor([batch_size])\n",
    "#  'target': torch.LongTensor([batch_size, max_len'-2])\n",
    "#  'raw_words': List[List[str]]  # 因为无法判断，所以 Collator 不会做任何处理"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3583df6d",
   "metadata": {},
   "source": [
    "### 4. 模型准备\n",
    "传入给fastNLP的模型，需要有两个特殊的方法``train_step``、``evaluate_step``，前者默认在 fastNLP.Trainer 中进行调用，后者默认在 fastNLP.Evaluator 中调用。如果模型中没有``train_step``方法，则Trainer会直接使用模型的``forward``函数；如果模型没有``evaluate_step``方法，则Evaluator会直接使用模型的``forward``函数。``train_step``方法（或当其不存在时，``forward``方法）的返回值必须为 dict 类型，并且必须包含``loss``这个 key 。\n",
    "\n",
    "此外fastNLP会使用形参名匹配的方式进行参数传递，例如以下模型\n",
    "```python\n",
    "class Model(nn.Module):\n",
    "   def train_step(self, x, y):\n",
    "        return {'loss': (x-y).abs().mean()}\n",
    "```\n",
    "fastNLP将尝试从 DataLoader 返回的 batch(假设包含的 key 为 input_ids, target) 中寻找 'x' 和 'y' 这两个 key ，如果没有找到则会报错。有以下的方法可以解决报错\n",
    "- 修改 train_step 的参数为(input_ids, target)，以保证和 DataLoader 返回的 batch 中的 key 匹配\n",
    "- 修改 DataLoader 中返回 batch 的 key 的名字为 (x, y)\n",
    "- 在 Trainer 中传入参数 train_input_mapping={'input_ids': 'x', 'target': 'y'} 将输入进行映射，train_input_mapping 也可以是一个函数，更多 train_input_mapping 的介绍可以参考文档。\n",
    "\n",
    "``evaluate_step``也是使用同样的匹配方式，前两条解决方法是一致的，第三种解决方案中，需要在 Evaluator 中传入 evaluate_input_mapping={'input_ids': 'x', 'target': 'y'}。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f131c1a3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:48:21] </span><span style=\"color: #800000; text-decoration-color: #800000\">WARNING </span> Some weights of the model checkpoint at            <a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">modeling_utils.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py#1490\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1490</span></a>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         bert-base-uncased were not used when initializing  <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         BertModel: <span style=\"font-weight: bold\">[</span><span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.bias'</span>,                <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.LayerNorm.weight'</span>,      <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         <span style=\"color: #008000; text-decoration-color: #008000\">'cls.seq_relationship.weight'</span>,                     <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.decoder.weight'</span>,                  <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.dense.weight'</span>,          <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.LayerNorm.bias'</span>,        <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.dense.bias'</span>,            <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         <span style=\"color: #008000; text-decoration-color: #008000\">'cls.seq_relationship.bias'</span><span style=\"font-weight: bold\">]</span>                       <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         - This IS expected if you are initializing         <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         BertModel from the checkpoint of a model trained   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         on another task or with another architecture <span style=\"font-weight: bold\">(</span>e.g. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         initializing a BertForSequenceClassification model <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         from a BertForPreTraining model<span style=\"font-weight: bold\">)</span>.                  <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         - This IS NOT expected if you are initializing     <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         BertModel from the checkpoint of a model that you  <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         expect to be exactly identical <span style=\"font-weight: bold\">(</span>initializing a     <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         BertForSequenceClassification model from a         <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         BertForSequenceClassification model<span style=\"font-weight: bold\">)</span>.              <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m[10:48:21]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m Some weights of the model checkpoint at            \u001b]8;id=387614;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=648168;file://../../fastNLP/transformers/torch/modeling_utils.py#1490\u001b\\\u001b[2m1490\u001b[0m\u001b]8;;\u001b\\\n",
       "\u001b[2;36m           \u001b[0m         bert-base-uncased were not used when initializing  \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         BertModel: \u001b[1m[\u001b[0m\u001b[32m'cls.predictions.bias'\u001b[0m,                \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         \u001b[32m'cls.predictions.transform.LayerNorm.weight'\u001b[0m,      \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         \u001b[32m'cls.seq_relationship.weight'\u001b[0m,                     \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         \u001b[32m'cls.predictions.decoder.weight'\u001b[0m,                  \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         \u001b[32m'cls.predictions.transform.dense.weight'\u001b[0m,          \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         \u001b[32m'cls.predictions.transform.LayerNorm.bias'\u001b[0m,        \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         \u001b[32m'cls.predictions.transform.dense.bias'\u001b[0m,            \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         \u001b[32m'cls.seq_relationship.bias'\u001b[0m\u001b[1m]\u001b[0m                       \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         - This IS expected if you are initializing         \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         BertModel from the checkpoint of a model trained   \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         on another task or with another architecture \u001b[1m(\u001b[0me.g. \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         initializing a BertForSequenceClassification model \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         from a BertForPreTraining model\u001b[1m)\u001b[0m.                  \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         - This IS NOT expected if you are initializing     \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         BertModel from the checkpoint of a model that you  \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         expect to be exactly identical \u001b[1m(\u001b[0minitializing a     \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         BertForSequenceClassification model from a         \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         BertForSequenceClassification model\u001b[1m)\u001b[0m.              \u001b[2m                      \u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> All the weights of BertModel were initialized from <a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">modeling_utils.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py#1507\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1507</span></a>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         the model checkpoint at bert-base-uncased.         <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         If your task is similar to the task the model of   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         the checkpoint was trained on, you can already use <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         BertModel for predictions without further          <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         training.                                          <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                      </span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m          \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m All the weights of BertModel were initialized from \u001b]8;id=544687;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=934505;file://../../fastNLP/transformers/torch/modeling_utils.py#1507\u001b\\\u001b[2m1507\u001b[0m\u001b]8;;\u001b\\\n",
       "\u001b[2;36m           \u001b[0m         the model checkpoint at bert-base-uncased.         \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         If your task is similar to the task the model of   \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         the checkpoint was trained on, you can already use \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         BertModel for predictions without further          \u001b[2m                      \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         training.                                          \u001b[2m                      \u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from fastNLP.transformers.torch import BertModel\n",
    "from fastNLP import seq_len_to_mask\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class BertNER(nn.Module):\n",
    "    def __init__(self, model_name, num_class, tag_vocab=None):\n",
    "        super().__init__()\n",
    "        self.bert = BertModel.from_pretrained(model_name)\n",
    "        self.mlp = nn.Sequential(nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size),\n",
    "                                nn.Dropout(0.3),\n",
    "                                nn.Linear(self.bert.config.hidden_size, num_class))\n",
    "        self.tag_vocab = tag_vocab  # 这里传入 tag_vocab 的目的是为了演示 constrined_decode \n",
    "        if tag_vocab is not None:\n",
    "            self._init_constrained_transition()\n",
    "    \n",
    "    def forward(self, input_ids, input_len, first):\n",
    "        attention_mask = seq_len_to_mask(input_len)\n",
    "        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
    "        last_hidden_state = outputs.last_hidden_state\n",
    "        first = first.unsqueeze(-1).repeat(1, 1, last_hidden_state.size(-1))\n",
    "        first_bpe_state = last_hidden_state.gather(dim=1, index=first)\n",
    "        first_bpe_state = first_bpe_state[:, 1:-1]  # 删除 cls 和 sep\n",
    "        \n",
    "        pred = self.mlp(first_bpe_state)\n",
    "        return {'pred': pred}\n",
    "    \n",
    "    def train_step(self, input_ids, input_len, first, target):\n",
    "        pred = self(input_ids, input_len, first)['pred']\n",
    "        loss = F.cross_entropy(pred.transpose(1, 2), target)\n",
    "        return {'loss': loss}\n",
    "    \n",
    "    def evaluate_step(self, input_ids, input_len, first):\n",
    "        pred = self(input_ids, input_len, first)['pred'].argmax(dim=-1)\n",
    "        return {'pred': pred}\n",
    "    \n",
    "    def constrained_decode(self, input_ids, input_len, first, first_len):\n",
    "        # 这个函数在推理时，将保证解码出来的 tag 一定不与前一个 tag 矛盾【例如一定不会出现 B-person 后面接着 I-Location 的情况】\n",
    "        # 本身这个需求可以在 Metric 中实现，这里在模型中实现的目的是为了方便演示：如何在fastNLP中使用不同的评测函数\n",
    "        pred = self(input_ids, input_len, first)['pred']\n",
    "        cons_pred = []\n",
    "        for _pred, _len in zip(pred, first_len):\n",
    "            _pred = _pred[:_len]\n",
    "            tags = [_pred[0].argmax(dim=-1).item()]  # 这里就不考虑第一个位置非法的情况了\n",
    "            for i in range(1, _len):\n",
    "                tags.append((_pred[i] + self.transition[tags[-1]]).argmax().item())\n",
    "            cons_pred.append(torch.LongTensor(tags))\n",
    "        cons_pred = pad_sequence(cons_pred, batch_first=True)\n",
    "        return {'pred': cons_pred}\n",
    "    \n",
    "    def _init_constrained_transition(self):\n",
    "        from fastNLP.modules.torch import allowed_transitions\n",
    "        allowed_trans = allowed_transitions(self.tag_vocab)\n",
    "        transition = torch.ones((len(self.tag_vocab), len(self.tag_vocab)))*-100000.0\n",
    "        for s, e in allowed_trans:\n",
    "            transition[s, e] = 0\n",
    "        self.register_buffer('transition', transition)\n",
    "\n",
    "model = BertNER('bert-base-uncased', len(data_bundle.get_vocab('target')), data_bundle.get_vocab('target'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5aeee1e9",
   "metadata": {},
   "source": [
    "### Trainer 的使用\n",
    "fastNLP 的 Trainer 是用于对模型进行训练的部件。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f4250f0b",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:49:22] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches.              <a href=\"file://../../fastNLP/core/controllers/trainer.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/controllers/trainer.py#661\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">661</span></a>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m[10:49:22]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches.              \u001b]8;id=246773;file://../../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=639347;file://../../fastNLP/core/controllers/trainer.py#661\u001b\\\u001b[2m661\u001b[0m\u001b]8;;\u001b\\\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #00d75f; text-decoration-color: #00d75f\">+++++++++++++++++++++++++++++ </span><span style=\"font-weight: bold\">Eval. results on Epoch:</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span><span style=\"font-weight: bold\">, Batch:</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span><span style=\"color: #00d75f; text-decoration-color: #00d75f\"> +++++++++++++++++++++++++++++</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[38;5;41m+++++++++++++++++++++++++++++ \u001b[0m\u001b[1mEval. results on Epoch:\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1m, Batch:\u001b[0m\u001b[1;36m0\u001b[0m\u001b[38;5;41m +++++++++++++++++++++++++++++\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#f\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.402447</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#f\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.447906</span>,\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#f\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.365365</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"f#f\"\u001b[0m: \u001b[1;36m0.402447\u001b[0m,\n",
       "  \u001b[1;34m\"pre#f\"\u001b[0m: \u001b[1;36m0.447906\u001b[0m,\n",
       "  \u001b[1;34m\"rec#f\"\u001b[0m: \u001b[1;36m0.365365\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:51:15] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> The best performance for monitor f#<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">f:0</span>.<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">402447</span> was <a href=\"file://../../fastNLP/core/callbacks/progress_callback.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">progress_callback.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/callbacks/progress_callback.py#37\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">37</span></a>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         achieved in Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Global Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">625</span>. The        <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                       </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         evaluation result:                                <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                       </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'f#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.402447</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'pre#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.447906</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'rec#f'</span>:     <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                       </span>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.365365</span><span style=\"font-weight: bold\">}</span>                                         <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                       </span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m[10:51:15]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m The best performance for monitor f#\u001b[1;92mf:0\u001b[0m.\u001b[1;36m402447\u001b[0m was \u001b]8;id=192029;file://../../fastNLP/core/callbacks/progress_callback.py\u001b\\\u001b[2mprogress_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=994998;file://../../fastNLP/core/callbacks/progress_callback.py#37\u001b\\\u001b[2m37\u001b[0m\u001b]8;;\u001b\\\n",
       "\u001b[2;36m           \u001b[0m         achieved in Epoch:\u001b[1;36m1\u001b[0m, Global Batch:\u001b[1;36m625\u001b[0m. The        \u001b[2m                       \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         evaluation result:                                \u001b[2m                       \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         \u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.402447\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.447906\u001b[0m, \u001b[32m'rec#f'\u001b[0m:     \u001b[2m                       \u001b[0m\n",
       "\u001b[2;36m           \u001b[0m         \u001b[1;36m0.365365\u001b[0m\u001b[1m}\u001b[0m                                         \u001b[2m                       \u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> Loading best model from buffer with f#f:  <a href=\"file://../../fastNLP/core/callbacks/load_best_model_callback.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">load_best_model_callback.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/callbacks/load_best_model_callback.py#115\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">115</span></a>\n",
       "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">           </span>         <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.402447</span><span style=\"color: #808000; text-decoration-color: #808000\">...</span>                               <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">                               </span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m          \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Loading best model from buffer with f#f:  \u001b]8;id=654516;file://../../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96586;file://../../fastNLP/core/callbacks/load_best_model_callback.py#115\u001b\\\u001b[2m115\u001b[0m\u001b]8;;\u001b\\\n",
       "\u001b[2;36m           \u001b[0m         \u001b[1;36m0.402447\u001b[0m\u001b[33m...\u001b[0m                               \u001b[2m                               \u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from torch import optim\n",
    "from fastNLP import Trainer, LoadBestModelCallback, TorchWarmupCallback\n",
    "from fastNLP import SpanFPreRecMetric\n",
    "\n",
    "optimizer = optim.AdamW(model.parameters(), lr=2e-5)\n",
    "callbacks = [\n",
    "    LoadBestModelCallback(),   # 用于在训练结束之后加载性能最好的model的权重\n",
    "    TorchWarmupCallback()\n",
    "]    \n",
    "\n",
    "trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer, \n",
    "                  evaluate_dataloaders=dls['dev'], \n",
    "                  metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n",
    "                  n_epochs=1, callbacks=callbacks, \n",
    "                  # 在评测时将 dataloader 中的 first_len 映射 seq_len, 因为 Accuracy.update 接口需要输入一个名为 seq_len 的参数\n",
    "                  evaluate_input_mapping={'first_len': 'seq_len'}, overfit_batches=0,\n",
    "                  device=0, monitor='f#f', fp16=False)  # fp16 为 True 的话，将使用 float16 进行训练。\n",
    "trainer.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c600a450",
   "metadata": {},
   "source": [
    "### Evaluator的使用\n",
    "fastNLP中用于评测数据的对象。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1b19f0ba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'f#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.390326</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'pre#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.414741</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'rec#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.368626</span><span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.390326\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.414741\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[1;36m0.368626\u001b[0m\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'f#f': 0.390326, 'pre#f': 0.414741, 'rec#f': 0.368626}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from fastNLP import Evaluator\n",
    "from fastNLP import SpanFPreRecMetric\n",
    "\n",
    "evaluator = Evaluator(model=model, dataloaders=dls['test'], \n",
    "                      metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n",
    "                      evaluate_input_mapping={'first_len': 'seq_len'}, \n",
    "                      device=0)\n",
    "evaluator.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52f87770",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f723fe399df34917875ad74c2542508c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 如果想评测一下使用 constrained decoding的性能，则可以通过传入 evaluate_fn 指定使用的函数\n",
    "def input_mapping(x):\n",
    "    x['seq_len'] = x['first_len']\n",
    "    return x\n",
    "evaluator = Evaluator(model=model, dataloaders=dls['test'], device=0,\n",
    "                      metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))},\n",
    "                      evaluate_fn='constrained_decode',\n",
    "                      # 如果将 first_len 重新命名为了 seq_len, 将导致 constrained_decode 的输入缺少 first_len 参数，因此\n",
    "                      #   额外重复一下 'first_len': 'first_len'，使得这个参数不会消失。\n",
    "                      evaluate_input_mapping=input_mapping)\n",
    "evaluator.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "419e718b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
